{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Neural Network using Augmentor and Keras\n",
    "\n",
    "In this notebook, we train a simple convolutional neural network on the MNIST dataset, using an Augmentor generator to augment the images on the fly.\n",
    "\n",
    "## Import Required Libraries\n",
    "\n",
    "We start by making a number of imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "# Make modules in the parent directory (e.g. a local Augmentor checkout) importable\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/envs/pyqae_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import Augmentor\n",
    "\n",
    "import keras\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Convolutional Neural Network\n",
    "\n",
    "Once the libraries have been imported, we define a small convolutional neural network. It follows the Keras MNIST CNN example, which describes the network in more detail: <https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py>\n",
    "\n",
    "It has two convolutional layers followed by two fully connected (dense) layers, with max pooling and dropout in between:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 10\n",
    "input_shape = (28, 28, 1)\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Conv2D(32, kernel_size=(3, 3),\n",
    "                 activation='relu',\n",
    "                 input_shape=input_shape))\n",
    "model.add(Conv2D(64, (3, 3), activation='relu'))\n",
    "model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "model.add(Dropout(0.25))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(128, activation='relu'))\n",
    "model.add(Dropout(0.5))\n",
    "model.add(Dense(num_classes, activation='softmax'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once a network has been defined, you can compile it so that the model is ready to be trained with data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=keras.losses.categorical_crossentropy,\n",
    "              optimizer=keras.optimizers.Adadelta(),\n",
    "              metrics=['accuracy'])"
   ]
  },
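  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `categorical_crossentropy` loss compares the network's softmax output with a one-hot label; for a single example it equals minus the log of the probability assigned to the true class. The cell below is a minimal, dependency-free sketch of that formula for one example. The helper `categorical_crossentropy` defined here is a hypothetical stand-in, not the Keras function, which also operates on whole batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def categorical_crossentropy(one_hot_label, probabilities, eps=1e-7):\n",
    "    # Hypothetical helper: loss for one example is -sum(y_i * log(p_i))\n",
    "    # Probabilities below eps are replaced by eps so that log(0) cannot occur\n",
    "    return -sum(y * math.log(max(p, eps)) for y, p in zip(one_hot_label, probabilities))\n",
    "\n",
    "# One-hot label for the digit 7, and a prediction giving class 7 probability 0.82\n",
    "label = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]\n",
    "prediction = [0.02] * 10\n",
    "prediction[7] = 0.82\n",
    "loss = categorical_crossentropy(label, prediction)\n",
    "print(round(loss, 4))"
   ]
  },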
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can view a summary of the network using the `summary()` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_1 (Conv2D)            (None, 26, 26, 32)        320       \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0         \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 12, 12, 64)        0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 9216)              0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 128)               1179776   \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 128)               0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 1,199,882\n",
      "Trainable params: 1,199,882\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
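  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes and parameter counts in the summary can be checked by hand. With the default `'valid'` padding and stride 1, a 3x3 convolution shrinks each spatial dimension by 2, and every layer's parameter count is its weights plus one bias per output unit. A small sketch of that arithmetic (the helper names are illustrative only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_params(kernel_size, in_channels, filters):\n",
    "    # One kernel_size x kernel_size x in_channels weight block per filter, plus one bias per filter\n",
    "    return (kernel_size * kernel_size * in_channels + 1) * filters\n",
    "\n",
    "def dense_params(inputs, units):\n",
    "    # One weight per (input, unit) pair, plus one bias per unit\n",
    "    return (inputs + 1) * units\n",
    "\n",
    "# 'valid' padding and stride 1: each 3x3 convolution removes 2 pixels per dimension\n",
    "side = 28 - 3 + 1             # conv2d_1: 26\n",
    "side = side - 3 + 1           # conv2d_2: 24\n",
    "side = side // 2              # max_pooling2d_1: 12\n",
    "flattened = side * side * 64  # flatten_1: 9216\n",
    "\n",
    "params = [\n",
    "    conv2d_params(3, 1, 32),       # conv2d_1\n",
    "    conv2d_params(3, 32, 64),      # conv2d_2\n",
    "    dense_params(flattened, 128),  # dense_1\n",
    "    dense_params(128, 10),         # dense_2\n",
    "]\n",
    "print(params, sum(params))"
   ]
  },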
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use Augmentor with a DataFrame\n",
    "\n",
    "Now we will feed images to Augmentor from a pandas DataFrame. We build the DataFrame by listing the downloaded image files with `glob` and parsing the data split and the digit label out of each file path.\n",
    "\n",
    "To get the data, we can use `wget` and `tar` (these commands may not be available under Windows):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('mnist_png'):\n",
    "    !wget https://rawgit.com/myleott/mnist_png/master/mnist_png.tar.gz\n",
    "    !tar -xf mnist_png.tar.gz"
   ]
  },
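  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `wget` or `tar` is not available (for example on Windows), the same download and extraction can be done with Python's standard library. This is a sketch that assumes the URL used above still serves the archive; `download_mnist_png` is a hypothetical helper name, and calling it has the same effect as the shell commands above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tarfile\n",
    "import urllib.request\n",
    "\n",
    "# Same URL as the wget command above; assumed to still serve the archive\n",
    "MNIST_URL = 'https://rawgit.com/myleott/mnist_png/master/mnist_png.tar.gz'\n",
    "\n",
    "def download_mnist_png(target_dir='mnist_png', url=MNIST_URL):\n",
    "    # Hypothetical helper: do nothing if the images are already extracted\n",
    "    if os.path.exists(target_dir):\n",
    "        return target_dir\n",
    "    archive = os.path.basename(url)  # 'mnist_png.tar.gz'\n",
    "    # Download the gzipped tar archive into the current directory\n",
    "    urllib.request.urlretrieve(url, archive)\n",
    "    # Extract into the current directory (assumed to create mnist_png/)\n",
    "    with tarfile.open(archive, 'r:gz') as tar:\n",
    "        tar.extractall()\n",
    "    return target_dir\n",
    "\n",
    "# Usage: download_mnist_png()"
   ]
  },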
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
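    "After the MNIST data has downloaded, we build the DataFrame. The images are stored as `mnist_png/<split>/<digit>/<id>.png`, so the last directory in each path gives the digit label and the one before it gives the data split (for example `training`). We then pass the training rows to an Augmentor `DataFramePipeline`, using the `path` column for the image locations and the `mnist_cat` column for the class labels:"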
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>path</th>\n",
       "      <th>data_split</th>\n",
       "      <th>mnist_cat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mnist_png/training/9/36655.png</td>\n",
       "      <td>training</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>mnist_png/training/9/32433.png</td>\n",
       "      <td>training</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>mnist_png/training/9/28319.png</td>\n",
       "      <td>training</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>mnist_png/training/9/4968.png</td>\n",
       "      <td>training</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>mnist_png/training/9/23502.png</td>\n",
       "      <td>training</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             path data_split mnist_cat\n",
       "0  mnist_png/training/9/36655.png   training         9\n",
       "1  mnist_png/training/9/32433.png   training         9\n",
       "2  mnist_png/training/9/28319.png   training         9\n",
       "3   mnist_png/training/9/4968.png   training         9\n",
       "4  mnist_png/training/9/23502.png   training         9"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
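    "from glob import glob\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# List every PNG, stored as mnist_png/<split>/<digit>/<id>.png\n",
    "image_df = pd.DataFrame(dict(path=glob('mnist_png/*/*/*.png')))\n",
    "# Split on the OS path separator so the parsing also works on Windows\n",
    "path_parts = image_df['path'].map(lambda x: os.path.normpath(x).split(os.sep))\n",
    "image_df['data_split'] = path_parts.map(lambda parts: parts[-3])\n",
    "image_df['mnist_cat'] = path_parts.map(lambda parts: parts[-2])\n",
    "image_df.head(5)"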
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialised with 60000 image(s) found.\n",
      "Output directory set to output."
     ]
    }
   ],
   "source": [
    "p = Augmentor.DataFramePipeline(image_df.query('data_split==\"training\"'), \n",
    "                                image_col = 'path', \n",
    "                                category_col = 'mnist_cat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Operations to the Pipeline\n",
    "\n",
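    "Now that a pipeline object `p` has been created, we can add operations to it. Below we add two simple operations: a top-to-bottom flip, applied with probability 0.1, and a random rotation of up to 5 degrees to the left or right, applied with probability 0.3:"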
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "p.flip_top_bottom(probability=0.1)\n",
    "p.rotate(probability=0.3, max_left_rotation=5, max_right_rotation=5)"
   ]
  },
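  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each operation is added with its own `probability`. One way to picture this is that, for every image drawn, each operation is applied independently with the given probability. The cell below is a pure-Python sketch of that idea on a tiny 3x3 image stored as nested lists; it is not Augmentor's own implementation, and `flip_top_bottom` and `apply_operations` are illustrative names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def flip_top_bottom(image):\n",
    "    # Reverse the order of the rows: the top row becomes the bottom row\n",
    "    return image[::-1]\n",
    "\n",
    "def apply_operations(image, operations, rng):\n",
    "    # Apply each (probability, operation) pair independently, in order\n",
    "    for probability, operation in operations:\n",
    "        if rng.random() < probability:\n",
    "            image = operation(image)\n",
    "    return image\n",
    "\n",
    "image = [[1, 1, 1],\n",
    "         [0, 0, 0],\n",
    "         [0, 0, 0]]\n",
    "\n",
    "rng = random.Random(0)\n",
    "operations = [(0.1, flip_top_bottom)]\n",
    "# Estimate how often the flip fires over many draws (should be close to 0.1)\n",
    "flipped = sum(apply_operations(image, operations, rng) != image for _ in range(10000))\n",
    "print(flipped / 10000)"
   ]
  },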
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can view the status of the pipeline using the `status()` function, which shows the operations that have been added, the number of images, and the classes found in the pipeline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Operations: 2\n",
      "\t0: Flip (probability=0.1 top_bottom_left_right=TOP_BOTTOM )\n",
      "\t1: RotateRange (probability=0.3 max_left_rotation=-5 max_right_rotation=5 )\n",
      "Images: 60000\n",
      "Classes: 10\n",
      "\tClass index: 0 Class label: 0 \n",
      "\tClass index: 1 Class label: 1 \n",
      "\tClass index: 2 Class label: 2 \n",
      "\tClass index: 3 Class label: 3 \n",
      "\tClass index: 4 Class label: 4 \n",
      "\tClass index: 5 Class label: 5 \n",
      "\tClass index: 6 Class label: 6 \n",
      "\tClass index: 7 Class label: 7 \n",
      "\tClass index: 8 Class label: 8 \n",
      "\tClass index: 9 Class label: 9 \n",
      "Dimensions: 1\n",
      "\tWidth: 28 Height: 28\n",
      "Formats: 1\n",
      "\t PNG\n",
      "\n",
      "You can remove operations using the appropriate index and the remove_operation(index) function.\n"
     ]
    }
   ],
   "source": [
    "p.status()"
   ]
  },
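  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The status output shows how the string labels from the `mnist_cat` column map to class indices; here the mapping is the identity because the labels `'0'` to `'9'` sort in numeric order. The generator then turns each class index into a one-hot label vector. A minimal sketch of that encoding, assuming indices are assigned to the sorted labels (this matches the status output above but is an assumption about Augmentor's internals), with a hypothetical `one_hot` helper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class labels as they appear in the mnist_cat column (strings parsed from paths)\n",
    "class_labels = sorted({'9', '0', '7', '3', '1', '2', '4', '5', '6', '8'})\n",
    "# Map each label to its position in the sorted list\n",
    "class_index = {label: index for index, label in enumerate(class_labels)}\n",
    "\n",
    "def one_hot(label, num_classes=len(class_labels)):\n",
    "    # Vector of zeros with a single 1 at the label's class index\n",
    "    vector = [0] * num_classes\n",
    "    vector[class_index[label]] = 1\n",
    "    return vector\n",
    "\n",
    "print(class_index['7'], one_hot('7'))"
   ]
  },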
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a Generator\n",
    "\n",
    "A generator yields batches of augmented images indefinitely, and we can use it as the input to the model created above. The generator is created with a user-defined batch size, which we store in a variable named `batch_size` because it is needed again later to compute the number of steps per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "g = p.keras_generator(batch_size=batch_size)"
   ]
  },
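  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a generator of this kind can be written with Python's `yield`: it keeps drawing random samples and never raises `StopIteration`. The sketch below is a simplified, dependency-free illustration using a hypothetical `endless_batches` function; it does not augment anything and is not Augmentor's implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def endless_batches(items, batch_size, seed=None):\n",
    "    # Hypothetical generator: yields random batches forever\n",
    "    rng = random.Random(seed)\n",
    "    while True:\n",
    "        # Sample with replacement so every batch has exactly batch_size items\n",
    "        yield [rng.choice(items) for _ in range(batch_size)]\n",
    "\n",
    "gen = endless_batches(list(range(60000)), batch_size=128, seed=0)\n",
    "first = next(gen)\n",
    "second = next(gen)\n",
    "print(len(first), len(second))"
   ]
  },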
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generator can now be used to create augmented data. In Python, the next item is retrieved from a generator with the built-in `next()` function. Augmentor generators return batches indefinitely, so `next()` can be called as often as required.\n",
    "\n",
    "You can view the output of the generator manually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "images, labels = next(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Images and their labels are returned in batches of the size defined above by `batch_size`. Each call to `next(g)` returns a tuple containing the augmented images and their corresponding labels, which we unpack into `images` and `labels`. The labels are one-hot encoded: each is a vector of length 10 with a 1 at the index of the image's class.\n",
    "\n",
    "To see the label of the first image in the batch, index into the array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 0 0 1 0 0]\n"
     ]
    }
   ],
   "source": [
    "print(labels[0])"
   ]
  },
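  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the label is one-hot encoded, the digit it represents is the index of its largest entry. With NumPy this is `np.argmax(labels[0])`; the dependency-free equivalent below (the helper name `decode_one_hot` is illustrative) recovers 7 from the label printed above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_one_hot(vector):\n",
    "    # Index of the largest entry; for a one-hot vector this is the position of the 1\n",
    "    return max(range(len(vector)), key=lambda i: vector[i])\n",
    "\n",
    "label = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]  # the label printed above\n",
    "print(decode_one_hot(label))"
   ]
  },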
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or preview the image using Matplotlib (it should be a 7, since the 1 in the label above is at index 7):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(images[0].reshape(28, 28), cmap=\"Greys\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the Network\n",
    "\n",
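    "We train the network by passing the generator, `g`, to the model. In the version of Keras used here, training from a generator uses the `fit_generator()` method rather than the standard `fit()` method (newer Keras releases accept generators in `fit()` directly and deprecate `fit_generator()`). The number of steps per epoch should roughly equal the total number of images in the dataset divided by `batch_size`, so that each epoch sees about as many images as the training set contains.\n",
    "\n",
    "Training the network for 5 epochs produces output like the following (this run was captured part-way through the first epoch):"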
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "156/468 [========>.....................] - ETA: 3:55 - loss: 0.7015 - acc: 0.7725"
     ]
    }
   ],
   "source": [
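    "# Integer number of full batches per epoch (60000 // 128 = 468)\n",
    "steps_per_epoch = len(p.augmentor_images) // batch_size\n",
    "h = model.fit_generator(g, steps_per_epoch=steps_per_epoch, epochs=5, verbose=1)"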
   ]
  },
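  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of steps per epoch is the number of training images divided by `batch_size`, which is not a whole number here. Rounding down (468 steps) leaves out the final partial batch, while rounding up (469 steps) draws slightly more than 60,000 images per epoch; because the generator never runs out of images, either choice is reasonable. A quick sketch of the arithmetic:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_images = 60000  # training images reported by p.status() above\n",
    "batch_size = 128\n",
    "\n",
    "exact = num_images / batch_size                  # 468.75\n",
    "floor_steps = num_images // batch_size           # full batches only: 468\n",
    "ceil_steps = math.ceil(num_images / batch_size)  # include the partial batch: 469\n",
    "\n",
    "print(exact, floor_steps, ceil_steps)\n",
    "# Images drawn per epoch: 59904 and 60032\n",
    "print(floor_steps * batch_size, ceil_steps * batch_size)"
   ]
  },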
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "Using Augmentor with Keras only requires creating a generator once your pipeline is complete. Because images are augmented on the fly, no augmented images need to be saved to disk."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pyqae_base]",
   "language": "python",
   "name": "conda-env-pyqae_base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
